{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BGE-EN-ICL"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, we will go through BGE-EN-ICL, an LLM based embedding model with both strong zero-shot and few-shot embedding capability."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0.Installation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the required packages in your environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -U transformers FlagEmbedding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. BGE-EN-ICL structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/share/project/xzy/Envs/ft/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00,  9.94it/s]\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModel\n",
    "import torch, os\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"BAAI/bge-en-icl\")\n",
    "raw_model = AutoModel.from_pretrained(\"BAAI/bge-en-icl\")\n",
    "\n",
    "sentences = [\"embedding\", \"I love machine learning and nlp\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Different from the previous BGE embedding models which are encoder only models, BGE-EN-ICL use decoder only LLM, Mistral-7B, as the base model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MistralModel(\n",
       "  (embed_tokens): Embedding(32003, 4096)\n",
       "  (layers): ModuleList(\n",
       "    (0-31): 32 x MistralDecoderLayer(\n",
       "      (self_attn): MistralSdpaAttention(\n",
       "        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "        (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "        (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "        (rotary_emb): MistralRotaryEmbedding()\n",
       "      )\n",
       "      (mlp): MistralMLP(\n",
       "        (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "        (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "        (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "        (act_fn): SiLU()\n",
       "      )\n",
       "      (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)\n",
       "      (post_attention_layernorm): MistralRMSNorm((4096,), eps=1e-05)\n",
       "    )\n",
       "  )\n",
       "  (norm): MistralRMSNorm((4096,), eps=1e-05)\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. New Pooling Method"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "BERT-like encoder only networks are considered with strong capacity for representation learning because of their bidirectional attention structure. Some previous work replace unidirectional attention with bidirectional attention during the embedding training phase. But this might creates a mismatch with the model's pre-training design, which could potentially undermine its in-context learning and generative properties.\n",
    "\n",
    "Thus BGE-EN-ICL introduces a [EOS] token's output embedding to address this issue."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': tensor([[    0,     0,     0,     0,     0,     0,     1, 28643,     2],\n",
       "        [    1,   315,  2016,  5599,  5168,   304,   307, 12312,     2]]), 'attention_mask': tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1],\n",
       "        [1, 1, 1, 1, 1, 1, 1, 1, 1]])}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs = tokenizer(\n",
    "    sentences,\n",
    "    padding=True,\n",
    "    return_tensors='pt',\n",
    ")\n",
    "inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 9, 4096])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "last_hidden_state = raw_model(**inputs, return_dict=True).last_hidden_state\n",
    "last_hidden_state.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The last token/[EOS] pooling method can be described as:\n",
    "\n",
    "Given the tokenized input sequence $T: [\\text{BOS}], t_1, ..., t_N$ is sent into the LLM:\n",
    "$$h_t = \\text{LLM}(T)[\\text{EOS}]$$\n",
    "where $h_t$ represents the text embedding taken from the output embedding of the special token [EOS]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def last_token_pool(last_hidden_states: torch.Tensor,\n",
    "                    attention_mask: torch.Tensor) -> torch.Tensor:\n",
    "    \n",
    "    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])\n",
    "    if left_padding:\n",
    "        return last_hidden_states[:, -1]\n",
    "    else:\n",
    "        sequence_lengths = attention_mask.sum(dim=1) - 1\n",
    "        batch_size = last_hidden_states.shape[0]\n",
    "        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 4096])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embeddings = last_token_pool(\n",
    "    last_hidden_state,  \n",
    "    attention_mask=inputs['attention_mask']\n",
    ")\n",
    "embeddings.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. In-Context Learning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "BGE-EN-ICL integrate strong in-context learning of LLM into embedding model while still persisting strong zero-shot embedding capability."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For zero-shot inference, it's exactly same to BGE v1&1.5. For few-shot inference, use the following way:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "examples = [\n",
    "    {\n",
    "        'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',\n",
    "        'query': 'what is a virtual interface',\n",
    "        'response': \"A virtual interface is a software-defined abstraction that mimics the behavior and characteristics of a physical network interface. It allows multiple logical network connections to share the same physical network interface, enabling efficient utilization of network resources. Virtual interfaces are commonly used in virtualization technologies such as virtual machines and containers to provide network connectivity without requiring dedicated hardware. They facilitate flexible network configurations and help in isolating network traffic for security and management purposes.\"\n",
    "    },\n",
    "    {\n",
    "        'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',\n",
    "        'query': 'causes of back pain in female for a week',\n",
    "        'response': \"Back pain in females lasting a week can stem from various factors. Common causes include muscle strain due to lifting heavy objects or improper posture, spinal issues like herniated discs or osteoporosis, menstrual cramps causing referred pain, urinary tract infections, or pelvic inflammatory disease. Pregnancy-related changes can also contribute. Stress and lack of physical activity may exacerbate symptoms. Proper diagnosis by a healthcare professional is crucial for effective treatment and management.\"\n",
    "    }\n",
    "]\n",
    "\n",
    "queries = [\"how much protein should a female eat\", \"summit define\"]\n",
    "documents = [\n",
    "    \"As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.\",\n",
    "    \"Definition of summit for English Language Learners. : 1  the highest point of a mountain : the top of a mountain. : 2  the highest level. : 3  a meeting or series of meetings between the leaders of two or more governments.\"\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00,  4.59it/s]\n",
      "pre tokenize: 100%|██████████| 1/1 [00:00<00:00, 501.41it/s]\n",
      "You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.6064 0.302 ]\n",
      " [0.257  0.5366]]\n"
     ]
    }
   ],
   "source": [
    "from FlagEmbedding import FlagICLModel\n",
    "\n",
    "model = FlagICLModel('BAAI/bge-en-icl', \n",
    "                     examples_for_task=examples,  # set `examples_for_task=None` to use model without examples\n",
    "                     examples_instruction_format=\"<instruct>{}\\n<query>{}\\n<response>{}\", # specify the format to use examples_for_task\n",
    "                     devices=[0],\n",
    "                    )\n",
    "\n",
    "embeddings_1 = model.encode_queries(queries)\n",
    "embeddings_2 = model.encode_corpus(documents)\n",
    "similarity = embeddings_1 @ embeddings_2.T\n",
    "\n",
    "print(similarity)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ft",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
